【pytorch学习笔记1】 您所在的位置:网站首页 no classes怎么读 【pytorch学习笔记1】

【pytorch学习笔记1】

2024-01-13 03:58| 来源: 网络整理| 查看: 265

数据读取Dataset与Dataloader 前言官方通用的数据加载器文件目录存储格式主要函数所有代码代码部分讲解官方通用的数据加载器收获 图片数据集(标签在图片名称上)构建自己的Dataset(重要)data列表构建总结 待续

前言

在pytorch学习这一块总是断断续续,完成大作业所写的代码再次回首已经完全看不懂了。所以我决定把学习过程中遇到的一些问题和知识总结出来,希望能取得一些进步吧。本人完全菜鸟,写这些笔记的主要目的是督促自己坚持学习下去,笔记中可能出现比较夸张的错误,恳请各位大佬谅解。 在数据集读取学习过程中遇到了很多很多很多困难,目前也只是对图片数据集(标签信息在图片名称上)读取稍微明白了一些,关于txt文件、CSV文件等。尤其是mat文件还是不是很明白怎么去处理。希望这篇笔记未来能把处理这些数据集的代码都写出来。

官方通用的数据加载器

以花卉分类为例

文件目录存储格式

在这里插入图片描述

主要函数

将图片数据集存储成指定上述存储格式后调用后面的函数即可实现

train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"), transform=data_transform["train"]) 所有代码 data_transform = { "train": transforms.Compose([transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]), "val": transforms.Compose([transforms.Resize((224, 224)), # cannot 224, must (224, 224) transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])} data_root = os.path.abspath(os.path.join(os.getcwd(), "../..")) # get data root path image_path = os.path.join(data_root, "data_set", "flower_data") # flower data set path assert os.path.exists(image_path), "{} path does not exist.".format(image_path) train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"), transform=data_transform["train"]) train_num = len(train_dataset) # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4} flower_list = train_dataset.class_to_idx cla_dict = dict((val, key) for key, val in flower_list.items()) # write dict into json file json_str = json.dumps(cla_dict, indent=4) with open('class_indices.json', 'w') as json_file: json_file.write(json_str) batch_size = 32 nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers print('Using {} dataloader workers every process'.format(nw)) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=nw) validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"), transform=data_transform["val"]) print("using {} images for training, {} images for validation.".format(train_num, val_num)) 代码部分讲解 创建列表 # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4} flower_list = train_dataset.class_to_idx cla_dict = dict((val, key) for key, val in flower_list.items()) # write dict into json file json_str = json.dumps(cla_dict, indent=4) with open('class_indices.json', 'w') as json_file: json_file.write(json_str)

将文件夹目录下的所有文件夹转换成字典,包含文件夹名称和对应的数字标签0,1,2,3,4

flower_list = train_dataset.class_to_idx # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}

颠倒key和val

cla_dict = dict((val, key) for key, val in flower_list.items()) # { 0:'daisy', 1:'dandelion', 2:'roses', 3:'sunflower', 4:'tulips'}

剩下的代码就是将其写入json文件

使用 json_path = './class_indices.json' json_file = open(json_path, "r") class_indict = json.load(json_file) print("class: {}".format(class_indict[str(i)]))#i=0,1,2,3,4

如i=0时,由{ 0:‘daisy’, 1:‘dandelion’, 2:‘roses’, 3:‘sunflower’, 4:‘tulips’}可知应该输出daisy

print("class: {}".format(class_indict[str(0)]))

在这里插入图片描述

官方通用的数据加载器收获

官方官方通用的数据加载器中可以利用class_to_idx获取字典格式最终存储成json文件

def find_classes(path): classes = [d.name for d in os.scandir(path) if d.is_dir()] classes.sort() class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} return class_to_idx

使用 在这里插入图片描述

image_path = os.path.join(os.getcwd(), "imagedata") # data set path bear_list = find_classes(image_path) print(bear_list)

结果

{'ballone': 0, 'ballthree': 1, 'balltwo': 2, 'innerone': 3, 'innerthree': 4, 'innertwo': 5, 'normal': 6, 'outerone': 7, 'outerthree': 8, 'outertwo': 9}

最后见上一节代码部分讲解即可保存json文件,方便之后预测使用。

图片数据集(标签在图片名称上)

这里使用完成大作业的代码进行学习,代码为轴承故障10分类问题,数据集为mat文件经过一系列操作转换为小波时频图jpg文件(转换过程的代码有时间争取我也总结一下)。图片文件存储地址以及图片如下

图片文件存储地址 小波时频图

构建自己的Dataset(重要) class MyDataset(Dataset): def __init__(self, data, transform, loder): self.data = data self.transform = transform self.loader = loder def __getitem__(self, item): img, label = self.data[item] img = self.loader(img) img = self.transform(img) return img, label def __len__(self): return len(self.data)

编程学的不好,术语不太会说,我就挑着代码里我不懂的地方简单说一下。

data: data是一个列表,格式为 [图片的地址(对应img),图片的标签(对应label)]。data列表具体怎么构建的下面会介绍。loder: 读取图片函数,这里调用的自己编写的Myloader函数 def Myloader(path): return Image.open(path).convert('RGB') Transform: 对数据进行预处理

具体使用如下

train = MyDataset(train_data, transform=transform_train, loder=Myloader) test = MyDataset(test_data, transform=transform_test, loder=Myloader) __ getitem__ 在DataLoader 送入torch中进行训练时,会自动调用数据集类的__getitem__()方法 train = MyDataset(train_data, transform=transform_train, loder=Myloader) train_data = DataLoader(dataset=train, batch_size=10, shuffle=True, num_workers=0) #截取DataLoader中的一段函数 def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,......

Dataset[T_co]从而调用类中的方法__getitem__,从而return img, label(见上面代码),从而得到每张照片(根据T_co的值)的img和label便于后续的训练代码

getitem返回方式两种都可以(如下) import torch from torch.utils.data import Dataset,DataLoader class MyDataset1(Dataset): def __init__(self): self.data = torch.tensor([[1,2,3],[2,3,4],[4,5,6]]) self.lable = torch.LongTensor([1,1,0,0]) def __getitem__(self,index): data = (self.data[index],self.lable[index]) return data def __len__(self): return len(self.data) class MyDataset2(Dataset): def __init__(self): self.data = torch.tensor([[1,2,3],[2,3,4],[4,5,6]]) self.lable = torch.LongTensor([1,1,0,0]) def __getitem__(self,index): return self.data[index], self.lable[index] def __len__(self): return len(self.data) mydataset1=MyDataset1() mydataset2=MyDataset2() mydataloder1 = DataLoader(dataset=mydataset1,batch_size=1) mydataloder2 = DataLoader(dataset=mydataset2,batch_size=1) for i,(data,label) in enumerate(mydataloder1): print(data,label) for i,(data,label) in enumerate(mydataloder2): print(data,label)

在这里插入图片描述

注意!train中for使用的不同 for batch_idx, (data, target) in enumerate(train_loader): for step, data in enumerate(train_bar): images, labels = data data列表构建

我的理解主要是干了这么一个事情,就是要把数据处理一下,创建了一个列表,里面装着图片和他的标签,所以步骤就是从文件中读取图片1转换成想要的格式,再读取图片1的标签,然后打包存进去。data里面就是[图片1,图片1的标签],[图片2,图片2的标签]…然后后续再索引需要的就可以了。现在的目的首先就是如何创建出data列表

# 得到一个包含路径与标签的列表 #标签通过find_label函数获取 #lens代表数据长度,也就是data中有几个数据,比如有2个就是data=[[图片1,图片1标签],[图片2,图片2标签]] def init_process(path, lens): data = [] name = find_label(path) for i in range(lens[0], lens[1]): data.append([path % i, name]) return data # 将图片名称中的 字母标签 转换成 0,1,2,3,4,5,6,7,8,9标签 #举例 #这里str输入的是图片路径,如path1 = r'C:\Users\lenovo\Desktop\Modern Signal Processing\xiaobo_CNN\imagedata\normal\normal_%d.jpg' #可以看出图片的英文标签是normal,经过函数find_label转换,标签变为0 def find_label(str): first, last = 0, 0 for i in range(len(str) - 1, -1, -1): if str[i] == '%' and str[i - 1] == '_': last = i - 1 if (str[i] == 'n' or str[i] == 'b' or str[i] == 'i' or str[i] == 'o') and str[i - 1] == '\\': first = i break name = str[first:last] #print(name) if name == 'normal': return 0 elif name == 'ballone': return 1 elif name == 'balltwo': return 2 elif name == 'ballthree': return 3 elif name == 'innerone': return 4 elif name == 'innertwo': return 5 elif name == 'innerthree': return 6 elif name == 'outerone': return 7 elif name == 'outertwo': return 8 elif name == 'outerthree': return 9

具体实现过程

#data1列表中包含了200张图片以及其对应的标签数据 #1.首先转换成data=[[图片1,图片1标签],[图片2,图片2标签]....]的格式,以data1举例 path1 = data1 = init_process(path1, [1,200]) #2.拼接所有的data形成训练数据集 train_data = data1[1:140] + data2[1:140]+ data3[1:140]+ data4[1:140]+ data5[1:140]+ data6[1:140]+ data7[1:140]+ data8[1:140]+ data9[1:140]+ data10[1:140] #3.由此可以创建出data列表,作为Dataset的输入,使用MyDataset函数 train = MyDataset(train_data, transform=transform_train, loder=Myloader) #4.使用DataLoader函数 train_data = DataLoader(dataset=train, batch_size=10, shuffle=True, num_workers=0) 总结

放上整个代码(路径就不放了,存的乱七八糟)

def Myloader(path): return Image.open(path).convert('RGB') def init_process(path, lens): data = [] name = find_label(path) for i in range(lens[0], lens[1]): data.append([path % i, name]) return data class MyDataset(Dataset): def __init__(self, data, transform, loder): self.data = data self.transform = transform self.loader = loder def __getitem__(self, item): img, label = self.data[item] img = self.loader(img) img = self.transform(img) return img, label def __len__(self): return len(self.data) def find_label(str): first, last = 0, 0 for i in range(len(str) - 1, -1, -1): if str[i] == '%' and str[i - 1] == '_': last = i - 1 if (str[i] == 'n' or str[i] == 'b' or str[i] == 'i' or str[i] == 'o') and str[i - 1] == '\\': first = i break name = str[first:last] #print(name) if name == 'normal': return 0 elif name == 'ballone': return 1 elif name == 'balltwo': return 2 elif name == 'ballthree': return 3 elif name == 'innerone': return 4 elif name == 'innertwo': return 5 elif name == 'innerthree': return 6 elif name == 'outerone': return 7 elif name == 'outertwo': return 8 elif name == 'outerthree': return 9 def load_data(): transform_train = transforms.Compose([ #transforms.RandomResizedCrop(224),#对图片尺寸做一个缩放切割 #transforms.RandomHorizontalFlip(),#图像一半的概率翻转,一半的概率不翻转 transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) transform_test = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) path1 = data1 = init_process(path1, [1,200]) path2 = data2 = init_process(path2, [1, 200]) path3 = data3 = init_process(path3, [1, 200]) path4 = data4 = init_process(path4, [1, 200]) path5 = data5 = init_process(path5, [1, 200]) path6 = data6 = init_process(path6, [1, 200]) path7 = data7 = init_process(path7, [1, 200]) path8 = data8 = init_process(path8, [1, 200]) path9 = data9 = init_process(path9, [1, 200]) path10 = data10 = init_process(path10, [1, 200]) # 800个训练 train_data = data1[1:140] + data2[1:140]+ data3[1:140]+ data4[1:140]+ data5[1:140]+ data6[1:140]+ data7[1:140]+ data8[1:140]+ data9[1:140]+ data10[1:140] train = MyDataset(train_data, transform=transform_train, loder=Myloader) # 200个测试 test_data = data1[141:180] + data2[161:180]+ data3[161:180]+ data4[161:180]+ data5[161:180]+ data6[161:180]+ data7[161:180]+ data8[161:180]+ data9[161:180]+ data10[161:180] test = MyDataset(test_data, transform=transform_test, loder=Myloader) train_data = DataLoader(dataset=train, batch_size=10, shuffle=True, num_workers=0) test_data = DataLoader(dataset=test, batch_size=1, shuffle=True, num_workers=0) return train_data, test_data 待续

mat文件 txt文件 CSV文件



【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有